# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import os
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
import torch
import json

from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional, Any, OrderedDict

from fairseq.data import AddTargetDataset, Dictionary, encoders
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import GenerationConfig
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel

from . import register_task
from .. import utils
from ..logging import metrics


logger = logging.getLogger(__name__)


class LabelEncoder(object):
    def __init__(self, dictionary):
        self.dictionary = dictionary

    def __call__(self, label):
        return self.dictionary.encode_line(
            label, append_eos=False, add_if_not_exist=False
        )


def label_len_fn(label):
    return len(label.split(" "))


@dataclass
class AudioFinetuningConfig(AudioPretrainingConfig):
    # Options for reporting WER metrics during validation. Only applicable to
    # Seq2Seq models during fine-tuning
    eval_wer: bool = field(
        default=False, metadata={"help": "compute WER for Seq2Seq models"}
    )
    eval_wer_config: GenerationConfig = field(
        default_factory=lambda: GenerationConfig(),
        metadata={"help": "beam search config for evaluating wer during training"},
    )
    eval_wer_tokenizer: Any = field(
        default=None,
        metadata={"help": "tokenizer config for evaluating wer during training"},
    )
    eval_wer_post_process: str = field(
        default="letter",
        metadata={
            "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)"
        },
    )
    eval_bleu: bool = field(
        default=False, metadata={"help": "evaluation with BLEU scores"}
    )
    eval_bleu_detok: Optional[str] = field(
        default=None,
        metadata={
            "help": "detokenize before computing BLEU (e.g., 'moses'); "
            "required if using --eval-bleu; use 'space' to disable "
            "detokenization; see fairseq.data.encoders for other options"
        },
    )
    eval_bleu_detok_args: str = field(
        default="{}", metadata={"help": "args for building the tokenizer, if needed"}
    )
    eval_tokenized_bleu: bool = field(
        default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
    )
    eval_bleu_remove_bpe: Optional[str] = field(
        default=None, metadata={"help": "remove BPE before computing BLEU"}
    )
    eval_bleu_args: str = field(
        default="{}",
        metadata={
            "help": "generation args for BLUE scoring, e.g., "
            '\'{"beam": 4, "lenpen": 0.6}\''
        },
    )
    eval_bleu_print_samples: bool = field(
        default=False, metadata={"help": "print sample generations during validation"}
    )
    autoregressive: bool = field(
        default=False,
        metadata={
            "help": "required for autoregressive decoders (like seq2seq models); "
            "adds 'prev_output_tokens' to input and appends eos to target"
        },
    )
    rebuild_batches: bool = True
    target_dictionary: Optional[str] = field(
        default=None,
        metadata={
            "help": "override default dictionary location"
        }
    )

@register_task("audio_finetuning", dataclass=AudioFinetuningConfig)
class AudioFinetuningTask(AudioPretrainingTask):
    """ """

    cfg: AudioFinetuningConfig

    def __init__(
        self,
        cfg: AudioFinetuningConfig,
    ):
        super().__init__(cfg)
        self.blank_symbol = "<s>"

        self.state.add_factory("target_dictionary", self.load_target_dictionary)

    def load_target_dictionary(self):
        if self.cfg.labels:
            target_dictionary = self.cfg.data
            if self.cfg.target_dictionary:  # override dict
                target_dictionary = self.cfg.target_dictionary
            dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt")
            logger.info('Using dict_path : {}'.format(dict_path))
            return Dictionary.load(dict_path)
        return None

    def load_dataset(
        self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs
    ):
        super().load_dataset(split, task_cfg, **kwargs)

        task_cfg = task_cfg or self.cfg
        assert task_cfg.labels is not None
        text_compression_level = getattr(
            TextCompressionLevel, str(self.cfg.text_compression_level)
        )
        data_path = self.cfg.data
        if task_cfg.multi_corpus_keys is None:
            label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
            skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
            text_compressor = TextCompressor(level=text_compression_level)
            with open(label_path, "r") as f:
                labels = [
                    text_compressor.compress(l)
                    for i, l in enumerate(f)
                    if i not in skipped_indices
                ]

            assert len(labels) == len(self.datasets[split]), (
                f"labels length ({len(labels)}) and dataset length "
                f"({len(self.datasets[split])}) do not match"
            )

            process_label = LabelEncoder(self.target_dictionary)

            self.datasets[split] = AddTargetDataset(
                self.datasets[split],
                labels,
                pad=self.target_dictionary.pad(),
                eos=self.target_dictionary.eos(),
                batch_targets=True,
                process_label=process_label,
                label_len_fn=label_len_fn,
                add_to_input=task_cfg.get("autoregressive", False),
                text_compression_level=text_compression_level,
            )
        else:

            target_dataset_map = OrderedDict()

            multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")]
            corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}

            data_keys = [k.split(":") for k in split.split(",")]

            multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")]
            data_weights = []
            for key, file_name in data_keys:
                k = key.strip()
                label_path = os.path.join(data_path, f"{file_name.strip()}.{task_cfg.labels}")
                skipped_indices = getattr(self.dataset_map[split][k], "skipped_indices", set())
                text_compressor = TextCompressor(level=text_compression_level)
                with open(label_path, "r") as f:
                    labels = [
                        text_compressor.compress(l)
                        for i, l in enumerate(f)
                        if i not in skipped_indices
                    ]

                assert len(labels) == len(self.dataset_map[split][k]), (
                    f"labels length ({len(labels)}) and dataset length "
                    f"({len(self.dataset_map[split][k])}) do not match"
                )

                process_label = LabelEncoder(self.target_dictionary)

                # TODO: Remove duplication of code from the if block above
                target_dataset_map[k] = AddTargetDataset(
                    self.dataset_map[split][k],
                    labels,
                    pad=self.target_dictionary.pad(),
                    eos=self.target_dictionary.eos(),
                    batch_targets=True,
                    process_label=process_label,
                    label_len_fn=label_len_fn,
                    add_to_input=task_cfg.get("autoregressive", False),
                    text_compression_level=text_compression_level,
                )

                data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]])

            if len(target_dataset_map) == 1:
                self.datasets[split] = list(target_dataset_map.values())[0]
            else:
                self.datasets[split] = MultiCorpusDataset(target_dataset_map, distribution=data_weights, seed=0, sort_indices=True)

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.state.target_dictionary

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.cfg.eval_wer and self.cfg.autoregressive:
            metrics = self._inference_with_wer(self.sequence_generator, sample, model)
            logging_output["_num_char_errors"] = metrics["num_char_errors"]
            logging_output["_num_chars"] = metrics["num_chars"]
            logging_output["_num_word_errors"] = metrics["num_word_errors"]
            logging_output["_num_words"] = metrics["num_words"]
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = metrics.sys_len
            logging_output["_bleu_ref_len"] = metrics.ref_len
            # we split counts into separate entries so that they can be
            # summed efficiently across workers using fast-stat-sync
            assert len(metrics.counts) == 4
            for i in range(4):
                logging_output[f"_bleu_counts_{i}"] = metrics.counts[i]
                logging_output[f"_bleu_totals_{i}"] = metrics.totals[i]
        return loss, sample_size, logging_output

    def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
        model = super().build_model(model_cfg, from_checkpoint)

        if self.cfg.eval_wer and self.cfg.autoregressive:
            self.sequence_generator = self.build_generator(
                [model],
                self.cfg.eval_wer_config,
            )
            if self.cfg.eval_wer_tokenizer:
                self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer)
            else:
                self.tokenizer = None
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            assert self.cfg.eval_bleu_detok is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            detok_args = json.loads(self.cfg.eval_bleu_detok_args)
            self.tokenizer = encoders.build_tokenizer(
                Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
            )
            gen_args = json.loads(self.cfg.eval_bleu_args)
            gen_args = Namespace(**gen_args)
            self.sequence_generator = self.build_generator([model], gen_args)

        return model

    def _inference_with_wer(self, generator, sample, model):
        import editdistance

        def decode(toks):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_wer_post_process,
                escape_unk=True,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp = decode(gen_out[i][0]["tokens"])
            ref = decode(
                utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
            )
            num_char_errors += editdistance.eval(hyp, ref)
            num_chars += len(ref)
            hyp_words = hyp.split()
            ref_words = ref.split()
            num_word_errors += editdistance.eval(hyp_words, ref_words)
            num_words += len(ref_words)

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
        }

    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, is_ref):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
                    is_ref=True,  # don't count <unk> as matches to the hypo
                )
            )
        if self.cfg.eval_bleu_print_samples:
            logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
            logger.info("T-{} {}".format(sample["id"][0], refs[0]))

        eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        if self.cfg.eval_wer:
            zero = torch.scalar_tensor(0.0)
            num_char_errors = sum(
                log.get("_num_char_errors", zero) for log in logging_outputs
            )
            num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs)
            num_word_errors = sum(
                log.get("_num_word_errors", zero) for log in logging_outputs
            )
            num_words = sum(log.get("_num_words", zero) for log in logging_outputs)
            metrics.log_scalar("_num_char_errors", num_char_errors)
            metrics.log_scalar("_num_chars", num_chars)
            metrics.log_scalar("_num_word_errors", num_word_errors)
            metrics.log_scalar("_num_words", num_words)
            if num_chars > 0:
                metrics.log_derived(
                    "uer",
                    lambda meters: meters["_num_char_errors"].sum
                    * 100.0
                    / meters["_num_chars"].sum
                    if meters["_num_chars"].sum > 0
                    else float("nan"),
                )
            if num_words > 0:
                metrics.log_derived(
                    "wer",
                    lambda meters: meters["_num_word_errors"].sum
                    * 100.0
                    / meters["_num_words"].sum
                    if meters["_num_words"].sum > 0
                    else float("nan"),
                )
        if self.cfg.eval_bleu:
            len_keys = ["_bleu_sys_len", "_bleu_ref_len"]
            count_keys = [f"_bleu_counts_{i}" for i in range(4)]
            total_keys = [f"_bleu_totals_{i}" for i in range(4)]
            for k in len_keys + count_keys + total_keys:
                metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))

            import sacrebleu

            metrics.log_derived(
                "bleu",
                lambda meters: sacrebleu.compute_bleu(
                    correct=[meters[k].sum for k in count_keys],
                    total=[meters[k].sum for k in total_keys],
                    sys_len=meters["_bleu_sys_len"].sum,
                    ref_len=meters["_bleu_ref_len"].sum,
                    smooth_method="exp",
                ).score,
            )
